{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from sklearn.datasets import load_wine\n",
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.tree import DecisionTreeClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "决策树的准确率：0.694\n"
     ]
    }
   ],
   "source": [
    "# Load the wine data and make the same 80/20 split (random_state=1) that the AdaBoost cell\n",
    "# below uses, so this cell runs on a fresh kernel instead of relying on X_train/y_train\n",
    "# left over from a cell that is defined later in the notebook.\n",
    "wine = load_wine()\n",
    "X = pd.DataFrame(wine.data, columns=wine.feature_names)\n",
    "y = pd.Series(wine.target)\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.20, random_state=1)\n",
    "\n",
    "# Baseline weak learner: a decision tree limited to depth 1, split on the Gini index.\n",
    "base_model = DecisionTreeClassifier(max_depth=1, criterion='gini', random_state=1).fit(X_train, y_train)\n",
    "\n",
    "# Evaluate on the held-out test set (not the training set).\n",
    "y_pred = base_model.predict(X_test)\n",
    "print(f\"决策树的准确率：{accuracy_score(y_test,y_pred):.3f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "所有特征：['alcohol', 'malic_acid', 'ash', 'alcalinity_of_ash', 'magnesium', 'total_phenols', 'flavanoids', 'nonflavanoid_phenols', 'proanthocyanins', 'color_intensity', 'hue', 'od280/od315_of_diluted_wines', 'proline']\n",
      "训练数据量：142，测试数据量：36\n",
      "准确率：0.97\n"
     ]
    }
   ],
   "source": [
    "from sklearn.ensemble import AdaBoostClassifier\n",
    "from sklearn.model_selection import GridSearchCV\n",
    "from sklearn.datasets import load_wine\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn import metrics\n",
    "import pandas as pd\n",
    "\n",
    "# Load the wine dataset and wrap it in pandas objects that carry the feature names.\n",
    "wine = load_wine()\n",
    "print(f\"所有特征：{wine.feature_names}\")\n",
    "X = pd.DataFrame(wine.data, columns=wine.feature_names)\n",
    "y = pd.Series(wine.target)\n",
    "\n",
    "# Hold out 20% of the samples for testing; the fixed seed keeps the split reproducible.\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.20, random_state=1)\n",
    "print(f\"训练数据量：{len(X_train)}，测试数据量：{len(X_test)}\")\n",
    "\n",
    "# Boost at most 50 copies of the depth-1 tree (`base_model`, built in the previous cell)\n",
    "# with a learning rate of 0.8. random_state=1 matches the seed used by the grid-search\n",
    "# cell so that repeated runs produce the same ensemble and the same accuracy.\n",
    "# NOTE(review): newer scikit-learn releases rename `base_estimator` to `estimator`;\n",
    "# confirm against the installed version before upgrading.\n",
    "model = AdaBoostClassifier(base_estimator=base_model, n_estimators=50, learning_rate=0.8, random_state=1)\n",
    "\n",
    "# Train on the training split.\n",
    "model.fit(X_train, y_train)\n",
    "\n",
    "# Predict the held-out test split and measure accuracy.\n",
    "y_pred = model.predict(X_test)\n",
    "acc = metrics.accuracy_score(y_test, y_pred)\n",
    "\n",
    "# `.2f` prints exactly two decimals; the bare `.2` spec means two significant digits\n",
    "# and would render a perfect score as \"1.0\".\n",
    "print(f\"准确率：{acc:.2f}\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 使用GridSearchCV自动调参"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "最优超参数: {'learning_rate': 0.8, 'n_estimators': 42}\n"
     ]
    }
   ],
   "source": [
    "# Search grid: even ensemble sizes from 2 through 100 and ten learning rates from 0.1 to 1.\n",
    "hyperparameter_space = dict(\n",
    "    n_estimators=list(range(2, 102, 2)),\n",
    "    learning_rate=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1],\n",
    ")\n",
    "\n",
    "# Seeded AdaBoost estimator whose hyperparameters the grid search varies.\n",
    "ada_estimator = AdaBoostClassifier(algorithm='SAMME.R', random_state=1)\n",
    "\n",
    "# Pick the combination with the best accuracy under 5-fold cross-validation (cv=5);\n",
    "# n_jobs=-1 runs the fits in parallel on all available CPU cores.\n",
    "gs = GridSearchCV(\n",
    "    ada_estimator,\n",
    "    param_grid=hyperparameter_space,\n",
    "    scoring=\"accuracy\",\n",
    "    n_jobs=-1,\n",
    "    cv=5,\n",
    ")\n",
    "gs.fit(X_train, y_train)\n",
    "\n",
    "print(\"最优超参数:\", gs.best_params_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "47dfafb046f9703cd15ca753999a7e7c95274099825c7bcc45b473d6496cd1b0"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
